# -*- coding: utf-8 -*-
from warnings import warn

import matplotlib.pyplot as plt
import numpy as np
import scipy.signal

from ..misc import NeuroKitWarning
from .signal_interpolate import signal_interpolate


def signal_filter(
    signal,
    sampling_rate=1000,
    lowcut=None,
    highcut=None,
    method="butterworth",
    order=2,
    window_size="default",
    powerline=50,
    show=False,
):
    """**Signal filtering**

    Filter a signal using different methods such as "butterworth", "fir", "savgol" or "powerline"
    filters.

    Apply a lowpass (if "highcut" frequency is provided), highpass (if "lowcut" frequency is
    provided) or bandpass (if both are provided) filter to the signal. For the Butterworth and
    Bessel methods, providing a ``lowcut`` greater than ``highcut`` yields a bandstop filter
    instead. Missing values (NaN) are linearly interpolated before filtering and set back to NaN
    in the output.

    Parameters
    ----------
    signal : Union[list, np.array, pd.Series]
        The signal (i.e., a time series) in the form of a vector of values.
    sampling_rate : int
        The sampling frequency of the signal (in Hz, i.e., samples/second).
    lowcut : float
        Lower cutoff frequency in Hz. The default is ``None``. At least one of ``lowcut`` or
        ``highcut`` is required for every method except ``"savgol"`` and ``"powerline"``.
    highcut : float
        Upper cutoff frequency in Hz. The default is ``None``.
    method : str
        Can be one of ``"butterworth"``, ``"butterworth_ba"``, ``"butterworth_zi"``,
        ``"bessel"``, ``"fir"``, ``"savgol"`` or ``"powerline"`` (case-insensitive). Note that for
        Butterworth, the function uses the SOS method from :func:`.scipy.signal.sosfiltfilt`,
        recommended for general purpose filtering. One can also specify ``"butterworth_ba"`` for a
        more traditional and legacy method (often implemented in other software), or
        ``"butterworth_zi"`` for a single forward (causal) pass with :func:`.scipy.signal.sosfilt`
        whose initial conditions are scaled by the signal mean. ``"fir"`` requires the ``mne``
        package.
    order : int
        Only used if ``method`` is one of the Butterworth variants, ``"bessel"`` or ``"savgol"``
        (for which it is the polynomial order). Order of the filter (default is 2).
    window_size : int
        Only used if ``method`` is ``"savgol"`` or ``"fir"``. For ``"savgol"``, the length of the
        filter window (i.e. the number of coefficients); even values are incremented by one to make
        them odd. If default, it will be set to the sampling rate divided by 3, rounded and made
        odd (333 if the sampling rate is 1000 Hz). For ``"fir"``, it is passed to MNE as the filter
        length; if default, MNE chooses it automatically (``"auto"``).
    powerline : int
        Only used if ``method`` is ``"powerline"``.
        The powerline frequency (normally 50 Hz or 60 Hz).
    show : bool
        If ``True``, plot the filtered signal as an overlay of the original.

    See Also
    --------
    signal_detrend, signal_psd

    Returns
    -------
    array
        Vector containing the filtered signal.

    Raises
    ------
    ValueError
        If no cutoff frequency is provided for a frequency-based method, or if ``method`` is not
        recognized.

    Examples
    --------
    .. ipython:: python

      import numpy as np
      import pandas as pd
      import neurokit2 as nk

      signal = nk.signal_simulate(duration=10, frequency=0.5) # Low freq
      signal += nk.signal_simulate(duration=10, frequency=5) # High freq

      # Visualize Lowpass Filtered Signal using Different Methods
      @savefig p_signal_filter1.png scale=100%
      fig1 = pd.DataFrame({"Raw": signal,
                         "Butter_2": nk.signal_filter(signal, highcut=3, method="butterworth",
                          order=2),
                         "Butter_2_BA": nk.signal_filter(signal, highcut=3,
                          method="butterworth_ba", order=2),
                         "Butter_5": nk.signal_filter(signal, highcut=3, method="butterworth",
                          order=5),
                         "Butter_5_BA": nk.signal_filter(signal, highcut=3,
                          method="butterworth_ba", order=5),
                         "Bessel_2": nk.signal_filter(signal, highcut=3, method="bessel", order=2),
                         "Bessel_5": nk.signal_filter(signal, highcut=3, method="bessel", order=5),
                         "FIR": nk.signal_filter(signal, highcut=3, method="fir")}).plot(subplots=True)
      @suppress
      plt.close()

    .. ipython:: python

      # Visualize Highpass Filtered Signal using Different Methods
      @savefig p_signal_filter2.png scale=100%
      fig2 = pd.DataFrame({"Raw": signal,
                          "Butter_2": nk.signal_filter(signal, lowcut=2, method="butterworth",
                          order=2),
                          "Butter_2_ba": nk.signal_filter(signal, lowcut=2,
                          method="butterworth_ba", order=2),
                          "Butter_5": nk.signal_filter(signal, lowcut=2, method="butterworth",
                          order=5),
                          "Butter_5_BA": nk.signal_filter(signal, lowcut=2,
                          method="butterworth_ba", order=5),
                          "Bessel_2": nk.signal_filter(signal, lowcut=2, method="bessel", order=2),
                          "Bessel_5": nk.signal_filter(signal, lowcut=2, method="bessel", order=5),
                          "FIR": nk.signal_filter(signal, lowcut=2, method="fir")}).plot(subplots=True)
      @suppress
      plt.close()

    .. ipython:: python

      # Using Bandpass Filtering in real-life scenarios
      # Simulate noisy respiratory signal
      original = nk.rsp_simulate(duration=30, method="breathmetrics", noise=0)
      signal = nk.signal_distort(original, noise_frequency=[0.1, 2, 10, 100], noise_amplitude=1,
                                powerline_amplitude=1)

      # Bandpass between 10 and 30 breaths per minute (respiratory rate range)
      @savefig p_signal_filter3.png scale=100%
      fig3 = pd.DataFrame({"Raw": signal,
                           "Butter_2": nk.signal_filter(signal, lowcut=10/60, highcut=30/60,
                                                        method="butterworth", order=2),
                           "Butter_2_BA": nk.signal_filter(signal, lowcut=10/60, highcut=30/60,
                                                           method="butterworth_ba", order=2),
                           "Butter_5": nk.signal_filter(signal, lowcut=10/60, highcut=30/60,
                                                        method="butterworth", order=5),
                           "Butter_5_BA": nk.signal_filter(signal, lowcut=10/60, highcut=30/60,
                                                           method="butterworth_ba", order=5),
                           "Bessel_2": nk.signal_filter(signal, lowcut=10/60, highcut=30/60,
                                                        method="bessel", order=2),
                           "Bessel_5": nk.signal_filter(signal, lowcut=10/60, highcut=30/60,
                                                        method="bessel", order=5),
                           "FIR": nk.signal_filter(signal, lowcut=10/60, highcut=30/60,
                                                   method="fir"),
                           "Savgol": nk.signal_filter(signal, method="savgol")}).plot(subplots=True)
      @suppress
      plt.close()

    """
    method = method.lower()

    # NaNs are interpolated so the filters can run; their positions are restored at the end.
    signal_sanitized, missing = _signal_filter_missing(signal)

    if method in ["sg", "savgol", "savitzky-golay"]:
        filtered = _signal_filter_savgol(signal_sanitized, sampling_rate, order, window_size=window_size)
    elif method in ["powerline"]:
        filtered = _signal_filter_powerline(signal_sanitized, sampling_rate, powerline)
    else:

        # Sanity checks: frequency-based filters need at least one cutoff
        if lowcut is None and highcut is None:
            raise ValueError("NeuroKit error: signal_filter(): you need to specify a 'lowcut' or a 'highcut'.")

        if method in ["butter", "butterworth"]:
            filtered = _signal_filter_butterworth(signal_sanitized, sampling_rate, lowcut, highcut, order)
        elif method in ["butter_ba", "butterworth_ba"]:
            filtered = _signal_filter_butterworth_ba(signal_sanitized, sampling_rate, lowcut, highcut, order)
        elif method in ["butter_zi", "butterworth_zi"]:
            filtered = _signal_filter_butterworth_zi(signal_sanitized, sampling_rate, lowcut, highcut, order)
        elif method in ["bessel"]:
            filtered = _signal_filter_bessel(signal_sanitized, sampling_rate, lowcut, highcut, order)
        elif method in ["fir"]:
            filtered = _signal_filter_fir(signal_sanitized, sampling_rate, lowcut, highcut, window_size=window_size)
        else:
            # Adjacent literals are concatenated into ONE message (not passed as separate args)
            raise ValueError(
                "NeuroKit error: signal_filter(): 'method' should be"
                " one of 'butterworth', 'butterworth_ba', 'butterworth_zi', 'bessel',"
                " 'savgol', 'powerline' or 'fir'."
            )

    # Put the originally-missing samples back as NaN
    filtered[missing] = np.nan

    if show is True:
        plt.plot(signal, color="lightgrey")
        plt.plot(filtered, color="red", alpha=0.9)

    return filtered


# =============================================================================
# Savitzky-Golay (savgol)
# =============================================================================


def _signal_filter_savgol(signal, sampling_rate=1000, order=2, window_size="default"):
    """Smooth a signal with a Savitzky-Golay filter.

    The window length is resolved by ``_signal_filter_windowsize()`` (a string such as
    ``"default"`` maps to ``round(sampling_rate / 3)``). If the result is even, it is bumped
    to the next odd integer. ``order`` is passed as the polynomial order of the fit.

    For guidance on choosing the window length, see `Sadeghi, M., & Behnia, F. (2018). Optimum
    window length of Savitzky-Golay filters with arbitrary order. arXiv preprint arXiv:1808.10489.
    <https://arxiv.org/ftp/arxiv/papers/1808/1808.10489.pdf>`_.

    """
    length = _signal_filter_windowsize(window_size=window_size, sampling_rate=sampling_rate)
    # Ensure an odd window length
    length = length + 1 if length % 2 == 0 else length
    return scipy.signal.savgol_filter(signal, window_length=int(length), polyorder=order)


# =============================================================================
# FIR
# =============================================================================
def _signal_filter_fir(signal, sampling_rate=1000, lowcut=None, highcut=None, window_size="default"):
    """Filter a signal using a FIR filter."""
    try:
        import mne
    except ImportError:
        raise ImportError(
            "NeuroKit error: signal_filter(): the 'mne' module is required for this method to run. ",
            "Please install it first (`pip install mne`).",
        )

    if isinstance(window_size, str):
        window_size = "auto"

    filtered = mne.filter.filter_data(
        signal,
        sfreq=sampling_rate,
        l_freq=lowcut,
        h_freq=highcut,
        method="fir",
        fir_window="hamming",
        filter_length=window_size,
        l_trans_bandwidth="auto",
        h_trans_bandwidth="auto",
        phase="zero-double",
        fir_design="firwin",
        pad="reflect_limited",
        verbose=False,
    )
    return filtered


# =============================================================================
# Butterworth
# =============================================================================


def _signal_filter_butterworth(signal, sampling_rate=1000, lowcut=None, highcut=None, order=5):
    """Apply an IIR Butterworth filter forward and backward using second-order sections (SOS).

    The cutoffs are validated and the filter type (lowpass/highpass/bandpass/bandstop) is
    derived by ``_signal_filter_sanitize()``. Filtering uses ``scipy.signal.sosfiltfilt``.
    """
    cutoffs, btype = _signal_filter_sanitize(lowcut=lowcut, highcut=highcut, sampling_rate=sampling_rate)
    sections = scipy.signal.butter(order, cutoffs, btype=btype, output="sos", fs=sampling_rate)
    return scipy.signal.sosfiltfilt(sections, signal)


def _signal_filter_butterworth_ba(signal, sampling_rate=1000, lowcut=None, highcut=None, order=5):
    """Apply an IIR Butterworth filter forward and backward using numerator/denominator (b/a) form.

    Gustafsson's method (``method="gust"``) is tried first. If ``scipy.signal.filtfilt`` raises
    a ``ValueError`` with it, filtering is retried with the padding method (``method="pad"``).
    """
    cutoffs, btype = _signal_filter_sanitize(lowcut=lowcut, highcut=highcut, sampling_rate=sampling_rate)
    numerator, denominator = scipy.signal.butter(order, cutoffs, btype=btype, output="ba", fs=sampling_rate)

    try:
        return scipy.signal.filtfilt(numerator, denominator, signal, method="gust")
    except ValueError:
        # Fallback when the Gustafsson variant is rejected
        return scipy.signal.filtfilt(numerator, denominator, signal, method="pad")


def _signal_filter_butterworth_zi(signal, sampling_rate=1000, lowcut=None, highcut=None, order=5):
    """Apply an IIR Butterworth SOS filter in a single forward pass with an initial state (zi).

    The steady-state initial conditions from ``scipy.signal.sosfilt_zi`` are scaled by the
    mean of the signal before filtering with ``scipy.signal.sosfilt``. Only the filtered output
    is returned; the final filter state is discarded.
    """
    cutoffs, btype = _signal_filter_sanitize(lowcut=lowcut, highcut=highcut, sampling_rate=sampling_rate)
    sections = scipy.signal.butter(order, cutoffs, btype=btype, output="sos", fs=sampling_rate)

    initial_state = scipy.signal.sosfilt_zi(sections) * np.mean(signal)
    output, _final_state = scipy.signal.sosfilt(sections, signal, zi=initial_state)
    return output


# =============================================================================
# Bessel
# =============================================================================


def _signal_filter_bessel(signal, sampling_rate=1000, lowcut=None, highcut=None, order=5):
    """Apply an IIR Bessel filter forward and backward using second-order sections (SOS).

    The cutoffs are validated and the filter type is derived by ``_signal_filter_sanitize()``.
    The filter is designed with ``scipy.signal.bessel`` and applied with
    ``scipy.signal.sosfiltfilt``.
    """
    cutoffs, btype = _signal_filter_sanitize(lowcut=lowcut, highcut=highcut, sampling_rate=sampling_rate)
    sections = scipy.signal.bessel(order, cutoffs, btype=btype, output="sos", fs=sampling_rate)
    return scipy.signal.sosfiltfilt(sections, signal)


# =============================================================================
# Powerline
# =============================================================================


def _signal_filter_powerline(signal, sampling_rate, powerline=50):
    """Filter out 50 Hz powerline noise by smoothing the signal with a moving average kernel with the width of one
    period of 50Hz."""

    if sampling_rate >= 100:
        b = np.ones(int(sampling_rate / powerline))
    else:
        b = np.ones(2)
    a = [len(b)]
    y = scipy.signal.filtfilt(b, a, signal, method="pad")
    return y


# =============================================================================
# Utility
# =============================================================================
def _signal_filter_sanitize(lowcut=None, highcut=None, sampling_rate=1000, normalize=False):

    # Sanity checks
    if lowcut is not None or highcut is not None:
        if sampling_rate <= 2 * np.nanmax(np.array([lowcut, highcut], dtype=np.float64)):
            warn(
                "The sampling rate is too low. Sampling rate"
                " must exceed the Nyquist rate to avoid aliasing problem."
                f" In this analysis, the sampling rate has to be higher than {2 * highcut} Hz",
                category=NeuroKitWarning,
            )

    # Replace 0 by none
    if lowcut is not None and lowcut == 0:
        lowcut = None
    if highcut is not None and highcut == 0:
        highcut = None

    # Format
    if lowcut is not None and highcut is not None:
        if lowcut > highcut:
            filter_type = "bandstop"
        else:
            filter_type = "bandpass"
        # pass frequencies in order of lowest to highest to the scipy filter
        freqs = list(np.sort([lowcut, highcut]))
    elif lowcut is not None:
        freqs = lowcut
        filter_type = "highpass"
    elif highcut is not None:
        freqs = highcut
        filter_type = "lowpass"

    # Normalize frequency to Nyquist Frequency (Fs/2).
    # However, no need to normalize if `fs` argument is provided to the scipy filter
    if normalize is True:
        freqs = np.array(freqs) / (sampling_rate / 2)

    return freqs, filter_type


def _signal_filter_windowsize(window_size="default", sampling_rate=1000):
    if isinstance(window_size, str):
        window_size = int(np.round(sampling_rate / 3))
        if (window_size % 2) == 0:
            window_size + 1  # pylint: disable=W0104
    return window_size


def _signal_filter_missing(signal):
    """Interpolate missing data and save the indices of the missing data."""
    missing = np.where(np.isnan(signal))[0]
    if len(missing) > 0:
        return signal_interpolate(signal, method="linear"), missing
    else:
        return signal, missing
